{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Two coupled equilibria: protolysis of ammonia in water\n",
    "In this notebook we will look at how ``ChemPy`` can be used to formulate a system of (non-linear) equations from conservation laws and equilibrium equations. We will look att ammonia since it is a fairly well-known substance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import mul\n",
    "from functools import reduce\n",
    "from itertools import product\n",
    "import chempy\n",
    "from chempy.chemistry import Species, Equilibrium\n",
    "from chempy.equilibria import EqSystem, NumSysLog, NumSysLin\n",
    "import numpy as np\n",
    "import sympy as sp\n",
    "sp.init_printing()\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "sp.__version__, chempy.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's define our substances. ChemPy can parse chemical formulae into ``chempy.Substance`` instances with information on charge, composition, molar mass (and pretty-printing):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "substance_names = ['H+', 'OH-', 'NH4+', 'NH3', 'H2O']\n",
    "subst = {n: Species.from_formula(n) for n in substance_names}\n",
    "assert [subst[n].charge for n in substance_names] == [1, -1, 1, 0, 0], \"Charges of substances\"\n",
    "print(u'Composition of %s: %s' % (subst['NH3'].unicode_name, subst['NH3'].composition))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define some initial concentrations and governing equilibria:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_conc = {'H+': 1e-7, 'OH-': 1e-7, 'NH4+': 1e-7, 'NH3': 1.0, 'H2O': 55.5}\n",
    "x0 = [init_conc[k] for k in substance_names]\n",
    "H2O_c = init_conc['H2O']\n",
    "w_autop = Equilibrium({'H2O': 1}, {'H+': 1, 'OH-': 1}, 10**-14/H2O_c)\n",
    "NH4p_pr = Equilibrium({'NH4+': 1}, {'H+': 1, 'NH3': 1}, 10**-9.26)\n",
    "equilibria = w_autop, NH4p_pr\n",
    "[(k, init_conc[k]) for k in substance_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eqsys = EqSystem(equilibria, subst)\n",
    "eqsys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can solve the non-linear system of equations by calling the ``root`` method (the underlying representation uses [pyneqsys](https://github.com/bjodah/pyneqsys):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, sol, sane = eqsys.root(init_conc)\n",
    "x, sol['success'], sane"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This system is quite easy to solve for, if we have convergence problems we can try to solve a transformed system. As an example we will use ``NumSysLog``:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logx, logsol, sane = eqsys.root(init_conc, NumSys=(NumSysLog,))\n",
    "logx, logsol['success'], sane"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case they give essentially the same answer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x - logx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can create symbolic representations of these systems of equations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ny = len(substance_names)\n",
    "y = sp.symarray('y', ny)\n",
    "i = sp.symarray('i', ny)\n",
    "K = Kw, Ka = sp.symbols('K_w K_a')\n",
    "w_autop.param = Kw\n",
    "NH4p_pr.param = Ka\n",
    "ss = sp.symarray('s', ny)\n",
    "ms = sp.symarray('m', ny)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numsys_log = NumSysLog(eqsys, backend=sp)\n",
    "f = numsys_log.f(y, list(i)+list(K))\n",
    "f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numsys_lin = NumSysLin(eqsys, backend=sp)\n",
    "numsys_lin.f(y, i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "A, ks = eqsys.stoichs_constants(False, backend=sp)\n",
    "[reduce(mul, [b**e for b, e in zip(y, row)]) for row in A]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyneqsys.symbolic import SymbolicSys\n",
    "subs = list(zip(i, x0)) + [(Kw, 10**-14), (Ka, 10**-9.26)]\n",
    "numf = [_.subs(subs) for _ in f]\n",
    "neqs = SymbolicSys(list(y), numf)\n",
    "neqs.solve([0, 0, 0, 0, 0], solver='scipy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "j = sp.Matrix(1, len(f), lambda _, q: f[q]).jacobian(y)\n",
    "init_conc_j = {'H+': 1e-10, 'OH-': 1e-7, 'NH4+': 1e-7, 'NH3': 1.0, 'H2O': 55.5}\n",
    "xj = eqsys.as_per_substance_array(init_conc_j)\n",
    "jarr = np.array(j.subs(dict(zip(y, xj))).subs({Kw: 1e-14, Ka: 10**-9.26}).subs(\n",
    "            dict(zip(i, xj))))\n",
    "jarr = np.asarray(jarr, dtype=np.float64)\n",
    "np.log10(np.linalg.cond(jarr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "j.simplify()\n",
    "j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eqsys.composition_balance_vectors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numsys_rref_log = NumSysLog(eqsys, True, True, backend=sp)\n",
    "numsys_rref_log.f(y, list(i)+list(K))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.set_printoptions(4, linewidth=120)\n",
    "scaling = 1e8\n",
    "for rxn in eqsys.rxns:\n",
    "    rxn.param = rxn.param.subs({Kw: 1e-14, Ka: 10**-9.26})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, res, sane = eqsys.root(init_conc, rref_equil=True, rref_preserv=True)\n",
    "x, res['success'], sane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, res, sane = eqsys.root(init_conc, x0=eqsys.as_per_substance_array(\n",
    "        {'H+': 1e-11, 'OH-': 1e-3, 'NH4+': 1e-3, 'NH3': 1.0, 'H2O': 55.5}))\n",
    "res['success'], sane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, res, sane = eqsys.root(init_conc, x0=eqsys.as_per_substance_array(\n",
    "        {'H+': 1.7e-11, 'OH-': 3e-2, 'NH4+': 3e-2, 'NH3': 0.97, 'H2O': 55.5}))\n",
    "res['success'], sane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_conc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nc=60\n",
    "Hp_0 = np.logspace(-3, 0, nc)\n",
    "def plot_rref(**kwargs):\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(16, 6), subplot_kw=dict(xscale='log', yscale='log'))\n",
    "    return [eqsys.roots(init_conc, Hp_0, 'H+', plot_kwargs={'ax': axes.flat[i]}, rref_equil=e,\n",
    "                        rref_preserv=p, **kwargs) for i, (e, p) in enumerate(product(*[[False, True]]*2))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_lin = plot_rref(method='lm')\n",
    "[all(_[2]) for _ in res_lin]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for col_id in range(len(substance_names)):\n",
    "    for i in range(1, 4):\n",
    "        plt.subplot(1, 3, i, xscale='log')\n",
    "        plt.gca().set_yscale('symlog', linthresh=1e-14)\n",
    "        plt.plot(Hp_0, res_lin[0][0][:, col_id] - res_lin[i][0][:, col_id])\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eqsys.plot_errors(res_lin[0][0], init_conc, Hp_0, 'H+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_conc, eqsys.ns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_log = plot_rref(NumSys=NumSysLog)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eqsys.plot_errors(res_log[0][0], init_conc, Hp_0, 'H+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_log_lin = plot_rref(NumSys=(NumSysLog, NumSysLin))\n",
    "eqsys.plot_errors(res_log_lin[0][0], init_conc, Hp_0, 'H+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chempy.equilibria import NumSysSquare\n",
    "res_log_sq = plot_rref(NumSys=(NumSysLog, NumSysSquare))\n",
    "eqsys.plot_errors(res_log_sq[0][0], init_conc, Hp_0, 'H+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_sq = plot_rref(NumSys=(NumSysSquare,), method='lm')\n",
    "eqsys.plot_errors(res_sq[0][0], init_conc, Hp_0, 'H+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, res, sane = eqsys.root(x0, NumSys=NumSysLog, rref_equil=True, rref_preserv=True)\n",
    "x, res['success'], sane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, res, sane = eqsys.root(x, NumSys=NumSysLin, rref_equil=True, rref_preserv=True)\n",
    "x, res['success'], sane"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
